#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-ignore-all-errors

import torch
from .optimizer_args import *
from typing import Optional, List, Tuple
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    construct_split_state,
    WeightDecayMode,
    EmbeddingLocation,
    ComputeDevice,
    apply_split_helper,
)
from torch.optim.optimizer import Optimizer
import logging

{%- if is_fbcode %}

from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc

def _load_library(unified_path: str, cuda_path: str, hip_path: str) -> None:
    """
    Load a torch operator library, falling back to a legacy path on failure.

    ``unified_path`` is tried first.  If loading it raises any exception, the
    legacy library matching the current build is loaded instead:
    ``hip_path`` when ``torch.version.hip`` is truthy, ``cuda_path``
    otherwise.  An error raised by the fallback load is not caught here.
    """
    try:
        torch.ops.load_library(unified_path)
    except Exception:
        # Backwards compatibility: the unified library could not be loaded,
        # so pick the old per-backend library instead.
        legacy_path = hip_path if torch.version.hip else cuda_path
        torch.ops.load_library(legacy_path)

# Load the operator libraries this module depends on.  Each call lists the
# candidate build targets for one library (presumably new name first, then
# older names kept for backwards compatibility -- confirm against
# fbgemm_gpu.utils.loader).
load_torch_module_bc(
    "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_training_cpu",
    "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu_training",
)

load_torch_module(
    "//deeplearning/fbgemm/fbgemm_gpu/codegen:optimizer_ops",
    "//deeplearning/fbgemm/fbgemm_gpu/codegen:optimizer_ops_cuda",
    "//deeplearning/fbgemm/fbgemm_gpu/codegen:optimizer_ops_hip",
)

load_torch_module(
    "//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings",
    "//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings",
    "//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings_hip",
)


# Try the unified GPU training library first and fall back to the legacy
# HIP/CUDA-specific library.  This reuses _load_library instead of repeating
# its try/except fallback inline.
_load_library(
    "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_training_gpu",
    "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cuda_training",
    "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip_training",
)

{%- endif %}


# Currently the template is implemented specifically for rowwise_adagrad.  It
# might or might not be applicable for other optimizers
# TODO: Add support for other optimizers
class SplitEmbedding{{ optimizer_class_name }}(Optimizer):
    def __init__(
        self,
        params: SplitEmbeddingOptimizerParams,
        embedding_args: SplitEmbeddingArgs,
        embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]],
        feature_table_map: Optional[List[int]] = None,
        stochastic_rounding: bool = True,
        {%- if "learning_rate_tensor" in args.split_function_arg_names %}
        learning_rate: float = 0.01,
        {%- endif %}
        {%- if "eps" in args.split_function_arg_names %}
        eps: float = 1.0e-8,
        {%- endif %}
        {%- if "beta1" in args.split_function_arg_names %}
        beta1: float = 0.9,
        {%- endif %}
        {%- if "beta2" in args.split_function_arg_names %}
        beta2: float = 0.999,
        {%- endif %}
        {%- if "weight_decay" in args.split_function_arg_names %}
        weight_decay: float = 0.0,
        {%- endif %}
        {%- if "weight_decay_mode" in args.split_function_arg_names %}
        weight_decay_mode: WeightDecayMode = WeightDecayMode.NONE,
        {%- endif %}
        {%- if "eta" in args.split_function_arg_names %}
        eta: float = 0.001,
        {%- endif %}
        {%- if "momentum" in args.split_function_arg_names %}
        momentum: float = 0.9,
        {%- endif %}
    ) -> None:

        # TODO: Add arg checkers
        defaults = dict(
            {%- if "learning_rate_tensor" in args.split_function_arg_names %}
            learning_rate=learning_rate,
            {%- endif %}
            {%- if "eps" in args.split_function_arg_names %}
            eps=eps,
            {%- endif %}
            {%- if "beta1" in args.split_function_arg_names %}
            beta1=beta1,
            {%- endif %}
            {%- if "beta2" in args.split_function_arg_names %}
            beta2=beta2,
            {%- endif %}
            {%- if "weight_decay" in args.split_function_arg_names %}
            weight_decay=weight_decay,
            {%- endif %}
            {%- if "weight_decay_mode" in args.split_function_arg_names %}
            weight_decay_mode=weight_decay_mode,
            {%- endif %}
            {%- if "eta" in args.split_function_arg_names %}
            eta=eta,
            {%- endif %}
            {%- if "momentum" in args.split_function_arg_names %}
            momentum=momentum,
            {%- endif %}
        )

        super().__init__([params.weights_dev], defaults)

        # Assume rowwise_adagrad for now
        for group in self.param_groups:
            for p in group["params"]:
                if p is params.weights_dev:
                    state = self.state[p]
                    # Get device
                    device = p.device
                    assert p.is_cuda, "SplitEmbedding{{ optimizer_class_name }} only support GPU tensors"
                    # Dummy tensors for UVM and LXU cache
                    # TODO: support UVM and LXU cache
                    state["dummy_weights_uvm"] = torch.empty(
                            0,
                            device=device,
                            dtype=params.weights_dev.dtype)
                    state["dummy_lxu_cache_weights"] = torch.zeros(
                            0,
                            0,
                            device=device,
                            dtype=params.weights_dev.dtype)

                    {% for state_tensor in args.split_tensors %}
                    {% if state_tensor == "momentum1" %}
                    {% if optimizer in ["rowwise_adagrad", "rowwise_adagrad_with_counter", "rowwise_weighted_adagrad"] %}
                    rowwise = True
                    {% else %}
                    rowwise = False
                    {% endif %}
                    {% elif state_tensor == "momentum2" %}
                    {% if optimizer in ["partial_rowwise_adam", "partial_rowwise_lamb"] %}
                    rowwise = True
                    {% else %}
                    rowwise = False
                    {% endif %}
                    {% else %}
                    # row_counter or prev_iter
                    rowwise = True
                    {% endif %}
                    # Move this to the global state (if possible)
                    # Construct {{ state_tensor }} state
                    {{ state_tensor }}_state = construct_split_state(
                            embedding_specs,
                            rowwise=rowwise,
                            cacheable=False)

                    def set_state(name: str, tensor: torch.Tensor) -> None:
                        state[name] = tensor

                    apply_split_helper(
                        persistent_state_fn=set_state,
                        set_attr_fn=set_state,
                        current_device=device,
                        use_cpu=False,
                        feature_table_map=feature_table_map,
                        split={{ state_tensor }}_state,
                        prefix="{{ state_tensor }}",
                        dtype=torch.float32,
                    )
                    {% endfor %}
                else:
                    raise NotImplementedError(
                            "SplitEmbedding{{ optimizer_class_name }} only supports weights_dev update")

        self.params = params
        self.embedding_args = embedding_args
        self.stochastic_rounding = stochastic_rounding

        {%- if "learning_rate_tensor" in args.split_function_arg_names %}
        self.learning_rate_tensor = torch.tensor(
            learning_rate, device=torch.device("cpu"), dtype=torch.float
        )
        {%- endif %}
        {%- if "eps" in args.split_function_arg_names %}
        self.eps = eps
        {%- endif %}
        {%- if "beta1" in args.split_function_arg_names %}
        self.beta1 = beta1
        {%- endif %}
        {%- if "beta2" in args.split_function_arg_names %}
        self.beta2 = beta2
        {%- endif %}
        {%- if "weight_decay" in args.split_function_arg_names %}
        self.weight_decay = weight_decay
        {%- endif %}
        {%- if "weight_decay_mode" in args.split_function_arg_names %}
        self.weight_decay_mode = weight_decay_mode
        {%- endif %}
        {%- if "eta" in args.split_function_arg_names %}
        self.eta = eta
        {%- endif %}
        {%- if "momentum" in args.split_function_arg_names %}
        self.momentum = momentum
        {%- endif %}


    def step(self, closure=None) -> torch.Tensor:
        for group in self.param_groups:
            for p in group["params"]:
                if p is self.params.weights_dev:
                    state = self.state[p]
                    sparse_grad_dev_weights = p.grad
                    # Dummy tensors for UVM and LXU cache
                    dummy_weights_uvm = state["dummy_weights_uvm"]
                    dummy_lxu_cache_weights = state["dummy_lxu_cache_weights"]
                    {% for state_tensor in args.split_tensors %}
                    # {{ state_tensor }} state
                    {{ state_tensor }}_dev = state["{{ state_tensor }}_dev"]
                    {{ state_tensor }}_uvm = state["{{ state_tensor }}_uvm"]
                    {{ state_tensor }}_offsets = state["{{ state_tensor }}_offsets"]
                    {{ state_tensor }}_placements = state["{{ state_tensor }}_placements"]
                    {% endfor %}
                    assert sparse_grad_dev_weights.is_sparse, "sparse_grad_dev_weights must be sparse"
                else:
                    raise NotImplementedError(
                            "SplitEmbedding{{ optimizer_class_name }} only supports weights_dev update")

        # Call {{ optimizer }}
        torch.ops.fbgemm.split_embedding_{{ optimizer }}_update(
                dev_weights=self.params.weights_dev,
                uvm_weights=dummy_weights_uvm,
                lxu_cache_weights=dummy_lxu_cache_weights,
                grad_dev_weights=sparse_grad_dev_weights._values(),
                grad_dev_indices=sparse_grad_dev_weights._indices(),
                weights_placement=self.embedding_args.weights_placements,
                weights_offsets=self.embedding_args.weights_offsets,
                max_D=self.embedding_args.max_D,
                stochastic_rounding=self.stochastic_rounding,
                {%- if "learning_rate_tensor" in args.split_function_arg_names %}
                learning_rate_tensor=self.learning_rate_tensor,
                {%- endif %}
                {%- if "eps" in args.split_function_arg_names %}
                eps=self.eps,
                {%- endif %}
                {%- if "beta1" in args.split_function_arg_names %}
                beta1=self.beta1,
                {%- endif %}
                {%- if "beta2" in args.split_function_arg_names %}
                beta2=self.beta2,
                {%- endif %}
                {%- if "weight_decay" in args.split_function_arg_names %}
                weight_decay=self.weight_decay,
                {%- endif %}
                {%- if "weight_decay_mode" in args.split_function_arg_names %}
                weight_decay_mode=self.weight_decay_mode,
                {%- endif %}
                {%- if "eta" in args.split_function_arg_names %}
                eta=self.eta,
                {%- endif %}
                {%- if "momentum" in args.split_function_arg_names %}
                momentum=self.momentum,
                {%- endif %}
                {% for state_tensor in args.split_tensors %}
                # {{ state_tensor }}
                {{ state_tensor }}_dev={{ state_tensor }}_dev,
                {{ state_tensor }}_uvm={{ state_tensor }}_uvm,
                {{ state_tensor }}_placements={{ state_tensor }}_placements,
                {{ state_tensor }}_offsets={{ state_tensor }}_offsets,
                {% endfor %}
        )
